""" This file contains functions that compute the surplus of a match.

"""
import numpy as np
from numba import njit
from dask.distributed import Client, progress
from scipy.optimize import root as root
from src.models_new import states 

try:
    from src.interpolation.splines import eval_linear, UCGrid, nodes
except:
    from ..interpolation.splines import eval_linear, UCGrid, nodes


# SURPLUS FUNCTIONS ---------------------------------------------------------------------
@njit
def get_surplus_once(
        mri, tau, spec, state_initial, grid, values, max_sim_length,
        prob_extend, const, r, sigma, m_0, m_1, m_2, rho_0, rho_1, rho_2, rho_3,
        beta, seed):
    """Simulate one path of a match and return its surplus and value components.

    Each period t contributes a flow value
        m_0 + m_1 * mri + m_2 * state[0] * (rho_0 + rho_1 * mri + rho_2 * mri**2 + rho_3 * mri**3),
    weighted by beta**t and by the probability that the match is still alive.
    Every `tau` periods the contract ends. The match continues with weight
    `prob_extend(state, tau, mri)`. The remaining weight collects the
    interpolated value `eval_linear(grid, values, state)` at that (next) state.

    The flow value is linear in (m_0, m_1, m_2 * rho_k). The component sums
    c0..c5 therefore let a caller rebuild the flow part of `match_values` for
    any parameter vector without re-simulating:
        flow = m_0*c0 + m_1*c1 + m_2*(rho_0*c2 + rho_1*c3 + rho_2*c4 + rho_3*c5)
    c6 is the discounted, probability-weighted exit value. As computed here,
    match_values == flow + c6.

    Args:
        mri: mri of the match; held fixed along the simulated path.
        tau: number of periods per contract (the countdown restarts at tau).
        spec: not used in this function.
        state_initial: initial state. Only state[0] enters the flow value
            directly; the full state is passed to eval_linear, prob_extend and
            states.next_state.
        grid, values: interpolation grid and values for eval_linear.
        max_sim_length: the loop runs while t <= max_sim_length, i.e. for
            periods t = 0..max_sim_length.
        prob_extend: callable (state, tau, mri) -> extension probability.
            It is called inside an njit function, so presumably it must itself
            be numba-compiled; confirm.
        const, r, sigma: forwarded to states.next_state.
        m_0, m_1, m_2, rho_0, rho_1, rho_2, rho_3: flow-value parameters.
        beta: discount factor.
        seed: seed for np.random. Presumably it drives the shocks drawn in
            states.next_state; confirm.

    Returns:
        (surplus, match_values, outside_option, match_value_list,
         match_value_0, ..., match_value_6), where
        surplus = match_values - beta * outside_option.
    """

    np.random.seed(seed)
    state = state_initial
    contract_periods = tau
    # Per-period discounted terms; each is summed at the end.
    match_value_list = []
    match_value_list_0 = []
    match_value_list_1 = []
    match_value_list_2 = []
    match_value_list_3 = []
    match_value_list_4 = []
    match_value_list_5 = []
    match_value_list_6 = []

    t = 0
    prob_in_match = 1.0

    # Do contract simulation.
    # contract_periods is reset to tau whenever it hits 0 (below). For tau >= 1
    # the loop therefore ends only once t exceeds max_sim_length.
    while (contract_periods > 0) & (t <= max_sim_length):
        # Get the beta * V (outside option) component, evaluated at the state
        # after one transition.
        # NOTE(review): outside_option is assigned only here. If the loop never
        # reaches t == 1 (max_sim_length < 1, or tau < 1), it is unbound at the
        # `surplus = ...` line below.
        if t == 1:
            outside_option = eval_linear(grid, values, state)

        # Accumulate the pieces of the flow value (see docstring for how they
        # recombine with m_* and rho_*).
        match_value_list_0.append(prob_in_match * beta**t)
        match_value_list_1.append(prob_in_match * beta**t * mri)
        match_value_list_2.append(prob_in_match * beta**t * state[0])
        match_value_list_3.append(prob_in_match * beta**t * state[0] * mri)
        match_value_list_4.append(prob_in_match * beta**t * state[0] * mri ** 2)
        match_value_list_5.append(prob_in_match * beta**t * state[0] * mri ** 3)

        match_value = (m_0 + m_1 * mri + m_2 * state[0] * (rho_0 + rho_1 * mri + rho_2 * mri ** 2 + rho_3 * mri ** 3))
        match_value_list.append(prob_in_match * beta**t * match_value)

        state = states.next_state(state, const, r, sigma)

        contract_periods = contract_periods - 1
        t = t + 1

        # Choose to extend or not
        if contract_periods == 0:
            # NOTE: compared to the above, now in the NEXT state
            prob_extend_t = prob_extend(state, tau, mri)
            #print('PROB EXTEND', prob_extend_t)
            # Weight of paths that leave the match at this contract boundary.
            prob_exit_match_at_t = prob_in_match * (1.0 - prob_extend_t)

            # Exiting paths collect the interpolated value at the new state.
            match_values_t = prob_exit_match_at_t * beta ** t * eval_linear(grid, values, state)
            match_value_list.append(match_values_t)
            match_value_list_6.append(match_values_t)
            prob_in_match = prob_in_match * prob_extend_t
            contract_periods = tau

    match_values = np.array(match_value_list).sum()
    match_value_0 = np.array(match_value_list_0).sum()
    match_value_1 = np.array(match_value_list_1).sum()
    match_value_2 = np.array(match_value_list_2).sum()
    match_value_3 = np.array(match_value_list_3).sum()
    match_value_4 = np.array(match_value_list_4).sum()
    match_value_5 = np.array(match_value_list_5).sum()
    match_value_6 = np.array(match_value_list_6).sum()

    surplus = match_values - beta * outside_option

    return (surplus, match_values, outside_option, match_value_list,
            match_value_0, match_value_1, match_value_2, match_value_3,
            match_value_4, match_value_5, match_value_6)


#@njit
def get_surplus_all(mri_plus_state, tau, spec, grid, values, prob_extend, const, r, sigma,
                    seeds, max_sim_length, m_0, m_1, m_2, rho_0, rho_1, rho_2, rho_3, beta,
                    verbose=True, verbose_components=False):
    """Average `get_surplus_once` over one simulation per seed.

    Args:
        mri_plus_state: indexable whose entry 0 is the mri and whose entries
            1..4 are passed as the initial state.
        seeds: iterable of seeds; one simulated path is run for each.
        tau, spec, grid, values, prob_extend, const, r, sigma, max_sim_length,
        m_0, m_1, m_2, rho_0, rho_1, rho_2, rho_3, beta:
            forwarded unchanged to `get_surplus_once`.
        verbose: when True (and `verbose_components` is False), return the
            seed-averaged value components as well as the surplus.
        verbose_components: when True, return per-seed lists instead.

    Returns:
        verbose_components -> (mean surplus, per-seed surpluses,
                               per-seed component 4, per-seed outside options)
        verbose            -> (mean surplus, mean components 0..6,
                               mean outside option)
        otherwise          -> mean surplus
    """
    mri = mri_plus_state[0]
    state = mri_plus_state[1:5]

    # Positions of components 0..6 inside get_surplus_once's result tuple.
    component_slots = range(4, 11)

    surplus_draws = []
    outside_draws = []
    component_draws = [[] for _ in component_slots]

    for seed in seeds:
        sim = get_surplus_once(
            mri=mri, tau=tau, spec=spec, state_initial=state,
            max_sim_length=max_sim_length, grid=grid, values=values,
            prob_extend=prob_extend, const=const, r=r, sigma=sigma,
            m_0=m_0, m_1=m_1, m_2=m_2,
            rho_0=rho_0, rho_1=rho_1, rho_2=rho_2, rho_3=rho_3,
            beta=beta, seed=seed,
        )
        surplus_draws.append(sim[0])
        outside_draws.append(sim[2])
        for draws, slot in zip(component_draws, component_slots):
            draws.append(sim[slot])

    surplus = np.array(surplus_draws).mean()
    outside_option = np.array(outside_draws).mean()
    component_means = [np.array(draws).mean() for draws in component_draws]

    if verbose_components:
        return (surplus, surplus_draws, component_draws[4], outside_draws)
    if verbose:
        return (surplus, *component_means, outside_option)
    return surplus


# COMPUTE SURPLUS FAST ------------------------------------------------------------------
def get_surplus_array(mri, state, surplus_grid, surplus_values_by_tau_spec, delta,
                      well_outside_option_by_tau_spec=None, well_target=True):
    """Interpolate the surplus for every (tau, spec) at a fixed state.

    NOTE: the returned values already include the delta weighting.

    Args:
        mri: 1-D sequence of mri values at which to evaluate the surplus.
        state: remaining state coordinates, tiled against every mri value.
            Presumably 4 entries so the points match the 5-D surplus grid;
            confirm.
        surplus_grid: interpolation grid passed to `eval_linear`.
        surplus_values_by_tau_spec: dict keyed by (tau, spec) holding the
            surplus values on `surplus_grid`.
        delta: weight applied to the surplus. See `well_target`.
        well_outside_option_by_tau_spec: optional dict keyed like
            `surplus_values_by_tau_spec`. When given, each entry is
            (1 - delta) * surplus + interpolated outside option, and
            `well_target` is ignored.
        well_target: only used when no outside-option dict is given. If
            truthy the surplus is scaled by (1 - delta), otherwise by delta.

    Returns:
        dict mapping tau in (2, 3, 4) to an array whose columns are the
        'low', 'mid' and 'high' specs, in that order. The shape is presumably
        (len(mri), 3) if `eval_linear` returns one value per point; confirm.
    """
    mri_with_state = np.column_stack((mri, np.tile(state, (len(mri), 1))))

    surp_array_by_tau = dict()
    for tau in [2, 3, 4]:
        surp_list = list()
        for spec in ['low', 'mid', 'high']:
            surp_spec = eval_linear(
                surplus_grid,
                surplus_values_by_tau_spec[(tau, spec)],
                mri_with_state
            )
            # Logical (not bitwise) tests: a truthy non-bool well_target must
            # still select the (1 - delta) weighting.
            if well_outside_option_by_tau_spec is not None:
                well_outside_option = eval_linear(
                    surplus_grid,
                    well_outside_option_by_tau_spec[(tau, spec)],
                    mri_with_state
                )
                surp_list.append((1 - delta) * surp_spec + well_outside_option)
            elif well_target:
                surp_list.append((1 - delta) * surp_spec)
            else:
                surp_list.append(delta * surp_spec)

        # Rows of surp_list are specs; transpose so specs become columns.
        surp_array_by_tau[tau] = np.array(surp_list).T

    return surp_array_by_tau


# FAST WAY TO COMPUTE SURPLUS -----------------------------------------------------------
def init_fast_surplus(mri_grid, g_grid, n_grid, tau, spec, grid, values, prob_extend,
                      const, r, sigma, seeds, max_sim_length, beta, options):
    """Precompute, on a 5-D node grid, the simulated surplus components.

    Every flow parameter (m_*, rho_*) is passed to `get_surplus_all` as zero.
    Only the returned components are kept, and `get_surplus_once` builds those
    without using m_* or rho_*. They can later be recombined for any
    parameter vector, as `build_fast_surplus` does.

    Args:
        mri_grid: number of nodes on the first axis, which spans [-2, 4].
        g_grid: number of nodes on the second axis, which spans [2, 15].
        n_grid: number of nodes on each of the last three axes, which span [2, 35].
        tau, spec, grid, values, prob_extend, const, r, sigma, seeds,
        max_sim_length, beta: forwarded to `get_surplus_all`.
        options: dict with 'threads_per_worker' and 'n_workers' for the dask
            `Client`.

    Returns:
        (grid_surplus, nodes_grid, nodes_list, match_values_all).
        match_values_all has one row per node and 8 columns: components 0..6
        followed by the outside option.
    """
    # Set up the surplus grid and enumerate its nodes.
    n_axis = (2, 35, n_grid)
    grid_surplus = UCGrid((-2, 4, mri_grid), (2, 15, g_grid), n_axis, n_axis, n_axis)

    nodes_grid = nodes(grid_surplus)
    nodes_list = list(nodes_grid)

    # Evaluate every node in parallel on a local dask cluster.
    with Client(
        threads_per_worker=options['threads_per_worker'],
        n_workers=options['n_workers']
    ) as client:
        futures = client.map(
            get_surplus_all,
            nodes_list,
            tau=tau,
            spec=spec,
            grid=grid,
            values=values,
            prob_extend=prob_extend,
            const=const,
            r=r,
            sigma=sigma,
            seeds=seeds,
            max_sim_length=max_sim_length,
            m_0=0,
            m_1=0,
            m_2=0,
            rho_0=0,
            rho_1=0,
            rho_2=0,
            rho_3=0,
            beta=beta
        )
        progress(futures)
        results = [future.result() for future in futures]

    # get_surplus_all (verbose mode) returns
    # (surplus, mv_0, ..., mv_6, outside_option); keep slots 1..8.
    match_values_all = np.array(
        [[res[slot] for res in results] for slot in range(1, 9)]
    ).T
    return (grid_surplus, nodes_grid, nodes_list, match_values_all)


def _weighted_flow_components(components, params, spec, scale=1.0):
    """Combine component columns 0-5 with the flow parameters for `spec`.

    Each term is evaluated as ``column * scale * coefficient`` (left to right).
    With the default scale of 1.0 this equals ``column * coefficient``
    exactly, because multiplying by 1.0 is exact in IEEE arithmetic.
    """
    m_2 = params['m_2']
    return (
        components[:, 0] * scale * params[f'm_0_{spec}']
        + components[:, 1] * scale * params[f'm_1_{spec}']
        + components[:, 2] * scale * m_2 * params['rho_0']
        + components[:, 3] * scale * m_2 * params['rho_1']
        + components[:, 4] * scale * m_2 * params['rho_2']
        + components[:, 5] * scale * m_2 * params['rho_3']
    )


def build_fast_surplus(match_values_by_tau_spec, params, surplus_grid,
                       non_myopic_dict=None, ignore_dynamics=False, mri_by_node=None):
    """Assemble surplus values on the grid from precomputed components.

    Args:
        match_values_by_tau_spec: dict keyed by (tau, spec). Each value is a
            2-D array with one row per grid node and columns: components 0..6
            followed by the outside option (column 7).
        params: dict with 'm_0_<spec>', 'm_1_<spec>', 'm_2', 'rho_0'..'rho_3'
            and 'beta'.
        surplus_grid: 5-axis grid; entry [axis][2] is the node count on that
            axis, used to reshape the flat per-node results.
        non_myopic_dict: optional dict with 'prob_exit' and 'prob_match'
            (keyed by spec). When given, and dynamics are not ignored, a
            discounted flow-value term is built and subtracted from the
            surplus.
        ignore_dynamics: testing/prototyping shortcut that builds the surplus
            directly from `mri_by_node`, which is then required.
        mri_by_node: per-node mri values, used only when `ignore_dynamics`.

    Returns:
        (surplus_values_by_tau_spec, well_outside_option_by_tau_spec). The
        second item is None unless the non-myopic branch ran, i.e. a
        `non_myopic_dict` was given and `ignore_dynamics` is false. All
        arrays are reshaped to the grid shape.
    """
    grid_shape = tuple(surplus_grid[axis][2] for axis in range(5))
    # Previously the outside-option reshape also ran when ignore_dynamics was
    # set, and raised KeyError on the never-filled dict.
    use_non_myopic = (non_myopic_dict is not None) and not ignore_dynamics

    surplus_values_by_tau_spec = dict()
    well_outside_option_by_tau_spec = dict()

    for tau, spec in match_values_by_tau_spec:
        components = match_values_by_tau_spec[(tau, spec)]

        if ignore_dynamics:
            # This is more for testing/prototyping
            surplus = (
                params[f'm_0_{spec}']
                + mri_by_node * params[f'm_1_{spec}']
                + mri_by_node * params['m_2'] * params['rho_0']
                + mri_by_node * params['m_2'] * params['rho_1']
                + mri_by_node * params['m_2'] * params['rho_2']
                + mri_by_node * params['m_2'] * params['rho_3']
            )
        else:
            surplus = _weighted_flow_components(components, params, spec) + components[:, 6]
            if use_non_myopic:
                adjustment = (params['beta'] * (1 - non_myopic_dict['prob_exit'])
                              * non_myopic_dict['prob_match'][spec])
                well_outside_option = _weighted_flow_components(
                    components, params, spec, adjustment)
                well_outside_option_by_tau_spec[(tau, spec)] = well_outside_option.reshape(grid_shape)
                surplus = surplus - well_outside_option
            surplus = surplus - params['beta'] * components[:, 7]

        surplus_values_by_tau_spec[(tau, spec)] = surplus.reshape(grid_shape)

    if use_non_myopic:
        return surplus_values_by_tau_spec, well_outside_option_by_tau_spec
    return surplus_values_by_tau_spec, None


# CUTOFF FUNCTIONS ----------------------------------------------------------------------
def surplus_wrapper_for_cutoffs(x, state, surplus_grid, surplus_values):
    """Return the interpolated surplus at mri ``x[0]`` and the given state.

    This is the objective handed to `scipy.optimize.root` by `cutoffs`. The
    query point is the mri followed by the first four state entries.
    """
    point = np.array([x[0]] + [state[k] for k in range(4)])
    return eval_linear(surplus_grid, surplus_values, point)


def cutoffs(state, surplus_grid, surplus_values):
    """Solve for the mri at which the interpolated surplus crosses zero.

    Args:
        state: state coordinates, held fixed while searching over mri.
        surplus_grid: interpolation grid for `eval_linear`.
        surplus_values: surplus values on `surplus_grid`.

    Returns:
        The `scipy.optimize.root` result object. The cutoff mri (between
        accepting and not accepting) is ``result.x[0]``. Convergence
        (``result.success``) is not checked here.
    """
    # Start the search at mri = 0.75.
    return root(
        surplus_wrapper_for_cutoffs,
        x0=0.75,
        args=(state, surplus_grid, surplus_values),
    )


def get_cutoffs(state, surplus_grid, surplus_values_by_spec, mri_max, mri_state_grid, value_zero=False):
    """Get the mri acceptance cutoffs for every (spec, tau) at a given state.

    Args:
        state: the current state, held fixed while solving over mri.
        surplus_grid: interpolation grid for the surplus.
        surplus_values_by_spec: surplus values keyed by (tau, spec).
        mri_max: upper bound used for every maximum.
        mri_state_grid: mri values over which the masks are built.
        value_zero: if truthy, skip root-finding and use the widest cutoffs:
            low maximum = mri_max, mid/high minimum = 0.

    Returns:
        y_min: dict keyed by (spec, tau) with the minimum accepted mri. 'low'
            is 0.0; 'mid'/'high' are root solutions clamped to [0, 2].
        y_max: dict keyed by (spec, tau) with the maximum accepted mri. 'low'
            is a root solution clamped to [0.1, mri_max]; 'mid'/'high' are
            mri_max.
        cutoff_mask_by_spec: dict keyed by spec. Each value is a list of three
            boolean masks over `mri_state_grid`, for tau = 2, 3, 4.
    """
    taus = (2, 3, 4)

    # Get the minimums.
    y_min = dict()
    for tau in taus:
        y_min[('low', tau)] = 0.0
        for spec in ('mid', 'high'):
            if value_zero:
                y_min[(spec, tau)] = 0
            else:
                # NOTE(review): root convergence (`.success`) is not checked.
                x_root = cutoffs(
                    state, surplus_grid, surplus_values_by_spec[(tau, spec)]
                ).x[0]
                y_min[(spec, tau)] = min(max(x_root, 0.0), 2.0)

    # Get the maximums.
    y_max = dict()
    for tau in taus:
        if value_zero:
            y_max[('low', tau)] = mri_max
        else:
            x_root = cutoffs(
                state, surplus_grid, surplus_values_by_spec[(tau, 'low')]
            ).x[0]
            y_max[('low', tau)] = min(max(x_root, 0.1), mri_max)
        y_max[('mid', tau)] = mri_max
        y_max[('high', tau)] = mri_max

    # 'low' accepts mri below its cutoff; 'mid'/'high' accept mri above theirs.
    cutoff_mask_by_spec = {
        'low': [mri_state_grid <= y_max[('low', tau)] for tau in taus],
        'mid': [mri_state_grid >= y_min[('mid', tau)] for tau in taus],
        'high': [mri_state_grid >= y_min[('high', tau)] for tau in taus],
    }

    return y_min, y_max, cutoff_mask_by_spec


#%% ADD IN CUTOFF COMPUTATION (ACTUAL) --------------------------------------------------
def surplus_wrapper_for_cutoffs_actual(x, state, *args):
    """Return the simulated mean surplus at mri ``x[0]`` and the given state.

    This is the objective handed to `scipy.optimize.root` by `cutoffs_actual`.

    Args:
        x: length-1 array holding the candidate mri, as passed by `root`.
        state: initial state. It is appended after the mri to form the
            ``mri_plus_state`` vector that `get_surplus_all` unpacks
            (entry 0 = mri, entries 1..4 = state).
        *args: the remaining positional arguments of `get_surplus_all`,
            starting at ``tau``.
    """
    # get_surplus_all indexes mri_plus_state[0] and [1:5], so a bare scalar
    # mri cannot be passed on its own.
    mri_plus_state = np.concatenate(([x[0]], np.asarray(state, dtype=float)))
    return get_surplus_all(mri_plus_state, *args)


def cutoffs_actual(tau, spec, state, grid, values, prob_extend, const, r, sigma, seeds,
            max_sim_length, m_0, m_1, m_2, rho_0, rho_1, beta, rho_2=0.0, rho_3=0.0):
    """Solve for the mri at which the simulated mean surplus crosses zero.

    Arguments other than `state` are forwarded to `get_surplus_all`.
    `rho_2` and `rho_3` default to 0.0, which keeps the older linear-in-mri
    specification for existing positional callers.

    Returns:
        The `scipy.optimize.root` result object. The cutoff mri (between
        accepting and not accepting) is ``result.x[0]``. Convergence is not
        checked here.
    """
    # Find the root of the cutoff
    x0 = 0.75
    # After `state` (consumed by the wrapper), the order must match
    # get_surplus_all's positional parameters after mri_plus_state:
    # (tau, spec, grid, values, prob_extend, const, r, sigma, seeds,
    #  max_sim_length, m_0, m_1, m_2, rho_0, rho_1, rho_2, rho_3, beta, verbose).
    # verbose=False makes get_surplus_all return only the mean surplus.
    args = (
        state, tau, spec, grid, values, prob_extend, const, r, sigma,
        seeds, max_sim_length, m_0, m_1, m_2, rho_0, rho_1, rho_2, rho_3, beta, False
    )
    x_cutoff = root(surplus_wrapper_for_cutoffs_actual, x0=x0, args=args)

    return x_cutoff


def get_cutoffs_actual(state, grid, values_low, values_mid, values_high, prob_extend,
                const, r, sigma, seeds, max_sim_length, m_0_low, m_1_low, m_2_low,
                m_0_high, m_1_high, m_2_high, rho_0, rho_1, beta, value_zero=False):
    """Get mri cutoffs by root-finding on the simulated (not interpolated) surplus.

    `values_mid` and `value_zero` are accepted but not used.

    Returns:
        y_min: dict keyed by (spec, tau). 'low' and 'mid' are 0.0; 'high' is a
            root solution clamped to [0, 0.8].
        y_max: dict keyed by (spec, tau). 'low' is a root solution clamped to
            [0.8, 2.15]; 'mid' and 'high' are 2.15.
    """
    shared = (prob_extend, const, r, sigma, seeds, max_sim_length)
    low_params = (m_0_low, m_1_low, m_2_low, rho_0, rho_1, beta)
    high_params = (m_0_high, m_1_high, m_2_high, rho_0, rho_1, beta)

    # NOTE(review): the 'high' cutoff is solved with the *low* m-parameters,
    # and the 'low' cutoff with the *high* ones. This looks swapped; confirm
    # before relying on these cutoffs.
    y_min = dict()
    for tau in (2, 3, 4):
        y_min[('low', tau)] = 0.0
        y_min[('mid', tau)] = 0.0
        solution = cutoffs_actual(tau, 'high', state, grid, values_high, *shared, *low_params)
        y_min[('high', tau)] = min(max(solution.x[0], 0.0), 0.8)

    y_max = dict()
    for tau in (2, 3, 4):
        solution = cutoffs_actual(tau, 'low', state, grid, values_low, *shared, *high_params)
        y_max[('low', tau)] = min(max(solution.x[0], 0.8), 2.15)
        y_max[('mid', tau)] = 2.15
        y_max[('high', tau)] = 2.15

    return y_min, y_max
